# Credits Analysis: Core Analysis Functions
# Alex F. Gazmararian
# agazmararian@gmail.com

#' Standardize actor names for consistent analysis
#'
#' Collapses known lower-case spellings of congressional titles onto two
#' canonical labels. "senator" becomes "Senator"; the House variants
#' ("rep", "house rep", "house representative", "congressman",
#' "congresswoman") become "Representative". Matching is exact and
#' case-sensitive. Any other value, including `NA`, passes through unchanged.
#' @param actor Character vector of actor names
#' @return Character vector with standardized names
standardize_actor <- function(actor) {
  house_aliases <- c(
    "rep",
    "house rep",
    "house representative",
    "congressman",
    "congresswoman"
  )

  dplyr::case_when(
    actor %in% "senator" ~ "Senator",
    actor %in% house_aliases ~ "Representative",
    TRUE ~ actor
  )
}

#' Summarize statements by specified grouping variables
#'
#' Counts rows for each combination of `group_vars` and sorts by descending
#' count. If `data` has an `actor` column, any actor value containing the
#' substring "senator" (case-sensitive) is first collapsed to "senator", so
#' that variants of the title are counted together.
#' @param data Data frame containing statements
#' @param group_vars Character vector of grouping column names
#' @return Ungrouped data frame with one row per group plus an integer
#'   `n_statements` column, sorted by `n_statements` in descending order.
summarize_statements <- function(data, group_vars) {
  # Only touch `actor` when the column exists. Otherwise data masking would
  # error, or silently pick up an `actor` object from the calling environment.
  if ("actor" %in% names(data)) {
    data <- data %>%
      dplyr::mutate(
        # as.character() first: ifelse() on a factor returns integer codes.
        actor = as.character(actor),
        actor = ifelse(grepl("senator", actor, fixed = TRUE), "senator", actor)
      )
  }

  data %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(group_vars))) %>%
    dplyr::summarise(
      n_statements = dplyr::n(),
      .groups = "drop"
    ) %>%
    dplyr::arrange(dplyr::desc(n_statements))
}

#' Plot statements data with flexible aesthetics
#'
#' Draws a bar chart of `n_statements` against the column named by `x`.
#' Bars can optionally be filled by another column. X-axis tick labels are
#' rotated 45 degrees, and the x-axis title is suppressed.
#' @param summary_df Summarized data frame; must contain an `n_statements`
#'   column (e.g. output of `summarize_statements()`).
#' @param x X-axis variable name (string)
#' @param fill Fill variable name (string, optional)
#' @param title Plot title
#' @param ylab Y-axis label (optional; defaults to "Number of statements")
#' @param fill_palette Custom fill palette passed to
#'   `ggplot2::scale_fill_manual(values = ...)` (optional)
#' @return ggplot object
plot_statements <- function(summary_df, x, fill = NULL, title, ylab = NULL,
                            fill_palette = NULL) {
  x_sym <- rlang::sym(x)
  fill_mapping <- if (!is.null(fill)) {
    ggplot2::aes(fill = !!rlang::sym(fill))
  } else {
    NULL
  }
  # Explicit default instead of `%||%`, which is only in base R >= 4.4
  # (or available when rlang/purrr are attached).
  y_label <- if (is.null(ylab)) "Number of statements" else ylab

  p <- ggplot2::ggplot(summary_df, ggplot2::aes(x = !!x_sym, y = n_statements)) +
    # `linewidth` replaces the deprecated `size` for bar borders (ggplot2 3.4.0+).
    ggplot2::geom_col(fill_mapping, color = "white", linewidth = 0.1) +
    ggplot2::labs(
      title = title,
      x = NULL,
      y = y_label
    ) +
    ggplot2::theme_minimal() +
    ggplot2::theme(
      axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
    )

  if (!is.null(fill_palette)) {
    p <- p + ggplot2::scale_fill_manual(values = fill_palette)
  }

  p
}

#' Set factor levels for a column in a data frame
#'
#' Replaces `df[[col]]` with `factor(df[[col]], levels = levels)`. Values
#' that are not listed in `levels` become `NA`, following base `factor()`.
#' @param df Data frame
#' @param col Column name (string)
#' @param levels Character vector of factor levels, in the desired order
#' @return `df` with `col` converted to a factor with the given levels
set_factor_levels <- function(df, col, levels) {
  current_values <- df[[col]]
  releveled <- factor(current_values, levels = levels)
  df[[col]] <- releveled
  df
}

#' Calculate agreement metrics between two binary variables
#'
#' Inputs are coerced with `as.numeric()`, so logical vectors become 0/1.
#' Factors would become their integer codes, so pass 0/1 or logical values.
#' @param x First binary variable (0/1 or TRUE/FALSE)
#' @param y Second binary variable (0/1 or TRUE/FALSE); must have the same
#'   length as `x`
#' @param na.rm Logical, whether to drop pairs in which either value is NA
#'   before calculating (default TRUE). When FALSE, NAs propagate into the
#'   counts and proportions instead of being dropped.
#' @return A list containing agreement metrics:
#'   `pct_agreement` (proportion 0-1 of pairs with x == y; NaN if no pairs),
#'   `n_agree`, `n_disagree`, `n_total` (pair counts),
#'   `table` (x-by-y contingency table; NA values are not tabulated), and
#'   `pct_by_category` (named list with one proportion per observed value v:
#'   among pairs where x or y equals v, the share where both equal v).
calc_agreement <- function(x, y, na.rm = TRUE) {
  # complete.cases() would catch a mismatch only when na.rm = TRUE. With
  # na.rm = FALSE, `==` would otherwise silently recycle the shorter input.
  if (length(x) != length(y)) {
    stop("`x` and `y` must have the same length.", call. = FALSE)
  }

  # Convert to numeric so logical inputs become 0/1
  x <- as.numeric(x)
  y <- as.numeric(y)

  # Drop pairs where either value is missing
  if (na.rm) {
    complete <- complete.cases(x, y)
    x <- x[complete]
    y <- y[complete]
  }

  # Pairwise agreement; NA where a value is NA (only possible if !na.rm)
  agrees <- x == y

  n_total <- length(x)
  n_agree <- sum(agrees)
  n_disagree <- sum(!agrees)

  cont_table <- table(x, y, dnn = c("x", "y"))

  # Category-specific agreement: for each observed value, among the pairs in
  # which at least one side takes that value, the share where both do.
  # which() drops NA membership, so `if (NA)` can no longer arise when
  # na.rm = FALSE; an NA pair inside the category yields an NA proportion.
  pct_by_cat <- list()
  for (val in sort(unique(c(x, y)))) {
    in_cat <- which(x == val | y == val)
    if (length(in_cat) > 0) {
      pct_by_cat[[as.character(val)]] <- sum(agrees[in_cat]) / length(in_cat)
    }
  }

  list(
    pct_agreement = n_agree / n_total,
    n_agree = n_agree,
    n_disagree = n_disagree,
    n_total = n_total,
    table = cont_table,
    pct_by_category = pct_by_cat
  )
}

#' Print agreement metrics in a formatted way
#'
#' Writes a human-readable summary through `message()`, so it goes to the
#' message stream (stderr), not stdout.
#' @param agreement_results Output from calc_agreement function: a list with
#'   elements `pct_agreement`, `n_total`, `n_agree`, `n_disagree` and
#'   `pct_by_category`
#' @return `agreement_results`, invisibly, so the call can sit in a pipe
print_agreement <- function(agreement_results) {
  message("Agreement Summary:")
  message(sprintf("Overall Agreement: %.1f%%", agreement_results[["pct_agreement"]] * 100))
  message(sprintf("Total Cases: %d", agreement_results[["n_total"]]))
  message(sprintf("Agreements: %d", agreement_results[["n_agree"]]))
  message(sprintf("Disagreements: %d", agreement_results[["n_disagree"]]))

  by_category <- agreement_results[["pct_by_category"]]
  if (length(by_category) > 0) {
    message("\nAgreement by Category:")
    # `category` rather than `cat` to avoid shadowing base::cat()
    for (category in names(by_category)) {
      message(sprintf("  Category %s: %.1f%%", category, by_category[[category]] * 100))
    }
  }

  invisible(agreement_results)
}
